{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 이전 대화를 기억하는 Chain 생성방법\n",
    "\n",
    "이 내용을 이해하기 위한 사전 지식\n",
    "- `RunnableWithMessageHistory`: [https://wikidocs.net/235581](https://wikidocs.net/235581)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# API KEY를 환경변수로 관리하기 위한 설정 파일\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# API KEY 정보로드\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LangSmith 추적을 시작합니다.\n",
      "[프로젝트명]\n",
      "CH12-RAG\n"
     ]
    }
   ],
   "source": [
    "# LangSmith 추적을 설정합니다. https://smith.langchain.com\n",
    "# !pip install langchain-teddynote\n",
    "from langchain_teddynote import logging\n",
    "\n",
    "# 프로젝트 이름을 입력합니다.\n",
    "logging.langsmith(\"CH12-RAG\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 일반 Chain 에 대화기록 추가"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
    "from langchain_community.chat_message_histories import ChatMessageHistory\n",
    "from langchain_core.chat_history import BaseChatMessageHistory\n",
    "from langchain_core.runnables.history import RunnableWithMessageHistory\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "\n",
    "\n",
    "# 프롬프트 정의\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"당신은 Question-Answering 챗봇입니다. 주어진 질문에 대한 답변을 제공해주세요.\",\n",
    "        ),\n",
    "        # 대화기록용 key 인 chat_history 는 가급적 변경 없이 사용하세요!\n",
    "        MessagesPlaceholder(variable_name=\"chat_history\"),\n",
    "        (\"human\", \"#Question:\\n{question}\"),  # 사용자 입력을 변수로 사용\n",
    "    ]\n",
    ")\n",
    "\n",
    "# llm 생성\n",
    "llm = ChatOpenAI()\n",
    "\n",
    "# 일반 Chain 생성\n",
    "chain = prompt | llm | StrOutputParser()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "대화를 기록하는 체인 생성(`chain_with_history`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 세션 기록을 저장할 딕셔너리\n",
    "store = {}\n",
    "\n",
    "\n",
    "# 세션 ID를 기반으로 세션 기록을 가져오는 함수\n",
    "def get_session_history(session_ids):\n",
    "    print(f\"[대화 세션ID]: {session_ids}\")\n",
    "    if session_ids not in store:  # 세션 ID가 store에 없는 경우\n",
    "        # 새로운 ChatMessageHistory 객체를 생성하여 store에 저장\n",
    "        store[session_ids] = ChatMessageHistory()\n",
    "    return store[session_ids]  # 해당 세션 ID에 대한 세션 기록 반환\n",
    "\n",
    "\n",
    "chain_with_history = RunnableWithMessageHistory(\n",
    "    chain,\n",
    "    get_session_history,  # 세션 기록을 가져오는 함수\n",
    "    input_messages_key=\"question\",  # 사용자의 질문이 템플릿 변수에 들어갈 key\n",
    "    history_messages_key=\"chat_history\",  # 기록 메시지의 키\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "첫 번째 질문 실행"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[대화 세션ID]: abc123\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'안녕하세요, 테디님. 무엇을 도와드릴까요?'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain_with_history.invoke(\n",
    "    # 질문 입력\n",
    "    {\"question\": \"나의 이름은 테디입니다.\"},\n",
    "    # 세션 ID 기준으로 대화를 기록합니다.\n",
    "    config={\"configurable\": {\"session_id\": \"abc123\"}},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이어서 질문 실행"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[대화 세션ID]: abc123\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'당신의 이름은 테디입니다.'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain_with_history.invoke(\n",
    "    # 질문 입력\n",
    "    {\"question\": \"내 이름이 뭐라고?\"},\n",
    "    # 세션 ID 기준으로 대화를 기록합니다.\n",
    "    config={\"configurable\": {\"session_id\": \"abc123\"}},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. RAG + RunnableWithMessageHistory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "먼저 일반 RAG Chain 을 생성합니다. 단, 6단계의 prompt 에 `{chat_history}` 를 꼭 추가합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "from langchain_community.document_loaders import PDFPlumberLoader\n",
    "from langchain_community.vectorstores import FAISS\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
    "from langchain_community.chat_message_histories import ChatMessageHistory\n",
    "from langchain_core.runnables.history import RunnableWithMessageHistory\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from operator import itemgetter\n",
    "\n",
    "# 단계 1: 문서 로드(Load Documents)\n",
    "loader = PDFPlumberLoader(\"data/SPRI_AI_Brief_2023년12월호_F.pdf\")\n",
    "docs = loader.load()\n",
    "\n",
    "# 단계 2: 문서 분할(Split Documents)\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)\n",
    "split_documents = text_splitter.split_documents(docs)\n",
    "\n",
    "# 단계 3: 임베딩(Embedding) 생성\n",
    "embeddings = OpenAIEmbeddings()\n",
    "\n",
    "# 단계 4: DB 생성(Create DB) 및 저장\n",
    "# 벡터스토어를 생성합니다.\n",
    "vectorstore = FAISS.from_documents(documents=split_documents, embedding=embeddings)\n",
    "\n",
    "# 단계 5: 검색기(Retriever) 생성\n",
    "# 문서에 포함되어 있는 정보를 검색하고 생성합니다.\n",
    "retriever = vectorstore.as_retriever()\n",
    "\n",
    "# 단계 6: 프롬프트 생성(Create Prompt)\n",
    "# 프롬프트를 생성합니다.\n",
    "prompt = PromptTemplate.from_template(\n",
    "    \"\"\"You are an assistant for question-answering tasks. \n",
    "Use the following pieces of retrieved context to answer the question. \n",
    "If you don't know the answer, just say that you don't know. \n",
    "Answer in Korean.\n",
    "\n",
    "#Previous Chat History:\n",
    "{chat_history}\n",
    "\n",
    "#Question: \n",
    "{question} \n",
    "\n",
    "#Context: \n",
    "{context} \n",
    "\n",
    "#Answer:\"\"\"\n",
    ")\n",
    "\n",
    "# 단계 7: 언어모델(LLM) 생성\n",
    "# 모델(LLM) 을 생성합니다.\n",
    "llm = ChatOpenAI(model_name=\"gpt-4.1-mini\", temperature=0)\n",
    "\n",
    "# 단계 8: 체인(Chain) 생성\n",
    "chain = (\n",
    "    {\n",
    "        \"context\": itemgetter(\"question\") | retriever,\n",
    "        \"question\": itemgetter(\"question\"),\n",
    "        \"chat_history\": itemgetter(\"chat_history\"),\n",
    "    }\n",
    "    | prompt\n",
    "    | llm\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "대화를 저장할 함수 정의"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 세션 기록을 저장할 딕셔너리\n",
    "store = {}\n",
    "\n",
    "\n",
    "# 세션 ID를 기반으로 세션 기록을 가져오는 함수\n",
    "def get_session_history(session_ids):\n",
    "    print(f\"[대화 세션ID]: {session_ids}\")\n",
    "    if session_ids not in store:  # 세션 ID가 store에 없는 경우\n",
    "        # 새로운 ChatMessageHistory 객체를 생성하여 store에 저장\n",
    "        store[session_ids] = ChatMessageHistory()\n",
    "    return store[session_ids]  # 해당 세션 ID에 대한 세션 기록 반환\n",
    "\n",
    "\n",
    "# 대화를 기록하는 RAG 체인 생성\n",
    "rag_with_history = RunnableWithMessageHistory(\n",
    "    chain,\n",
    "    get_session_history,  # 세션 기록을 가져오는 함수\n",
    "    input_messages_key=\"question\",  # 사용자의 질문이 템플릿 변수에 들어갈 key\n",
    "    history_messages_key=\"chat_history\",  # 기록 메시지의 키\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "첫 번째 질문 실행"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[대화 세션ID]: rag123\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"삼성전자가 만든 생성형 AI의 이름은 '삼성 가우스'입니다.\""
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rag_with_history.invoke(\n",
    "    # 질문 입력\n",
    "    {\"question\": \"삼성전자가 만든 생성형 AI 이름은?\"},\n",
    "    # 세션 ID 기준으로 대화를 기록합니다.\n",
    "    config={\"configurable\": {\"session_id\": \"rag123\"}},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이어진 질문 실행"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[대화 세션ID]: rag123\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"The name of the generative AI created by Samsung Electronics is 'Samsung Gauss'.\""
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rag_with_history.invoke(\n",
    "    # 질문 입력\n",
    "    {\"question\": \"이전 답변을 영어로 번역해주세요.\"},\n",
    "    # 세션 ID 기준으로 대화를 기록합니다.\n",
    "    config={\"configurable\": {\"session_id\": \"rag123\"}},\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain-kr-lwwSZlnu-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
